{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aaf8ff51",
   "metadata": {},
   "source": [
    "# Build a RAG chain by generating embeddings for NVIDIA Triton documentation\n",
    "\n",
    "In this notebook we demonstrate how to build a RAG chain using [NVIDIA AI Endpoints for LangChain](https://python.langchain.com/docs/integrations/text_embedding/nvidia_ai_endpoints). We create a vector store by downloading web pages, generating their embeddings with an NVIDIA embedding model, and storing them in FAISS. We then showcase two different chat chains for querying the vector store. For this example, we use the NVIDIA Triton documentation website, though the code can be easily modified to use any other source.\n",
    "\n",
    "### First stage: load NVIDIA Triton documentation from the web, split it into chunks, and generate embeddings stored in a FAISS vector store\n",
    "\n",
    "To get started:\n",
    "\n",
    "1. Create a free account with the NVIDIA NGC service, which hosts AI solution catalogs, containers, models, etc.\n",
    "\n",
    "2. Navigate to Catalog > AI Foundation Models > (Model with API endpoint).\n",
    "\n",
    "3. Select the API option and click Generate Key.\n",
    "\n",
    "4. Save the generated key in the `NVIDIA_API_KEY` environment variable. From there, you should have access to the endpoints."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e55d2d53",
   "metadata": {},
   "source": [
    "First, install the prerequisite libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fd4dcc8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Requirement already satisfied: langchain in /usr/local/lib/python3.10/dist-packages (0.1.9)\n",
      "Requirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (6.0)\n",
      "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.0.29)\n",
      "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (3.9.3)\n",
      "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.0.2)\n",
      "Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in /usr/local/lib/python3.10/dist-packages (from langchain) (0.6.4)\n",
      "Requirement already satisfied: jsonpatch<2.0,>=1.33 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.33)\n",
      "Requirement already satisfied: langchain-community<0.1,>=0.0.21 in /usr/local/lib/python3.10/dist-packages (from langchain) (0.0.26)\n",
      "Requirement already satisfied: langchain-core<0.2,>=0.1.26 in /usr/local/lib/python3.10/dist-packages (from langchain) (0.1.29)\n",
      "Requirement already satisfied: langsmith<0.2.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (0.1.39)\n",
      "Requirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.22.2)\n",
      "Requirement already satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.10.7)\n",
      "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.31.0)\n",
      "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (8.2.3)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.1.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.3)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.4)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.2)\n",
      "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /usr/local/lib/python3.10/dist-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (3.21.1)\n",
      "Requirement already satisfied: typing-inspect<1,>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (0.9.0)\n",
      "Requirement already satisfied: jsonpointer>=1.9 in /usr/local/lib/python3.10/dist-packages (from jsonpatch<2.0,>=1.33->langchain) (2.4)\n",
      "Requirement already satisfied: anyio<5,>=3 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2,>=0.1.26->langchain) (3.7.1)\n",
      "Requirement already satisfied: packaging<24.0,>=23.2 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2,>=0.1.26->langchain) (23.2)\n",
      "Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /usr/local/lib/python3.10/dist-packages (from langsmith<0.2.0,>=0.1.0->langchain) (3.10.0)\n",
      "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (4.10.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.1.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.4)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2.2.1)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2022.12.7)\n",
      "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain) (3.0.3)\n",
      "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3->langchain-core<0.2,>=0.1.26->langchain) (1.3.1)\n",
      "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3->langchain-core<0.2,>=0.1.26->langchain) (1.1.1)\n",
      "Requirement already satisfied: mypy-extensions>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain) (1.0.0)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Requirement already satisfied: langchain_nvidia_ai_endpoints in /usr/local/lib/python3.10/dist-packages (0.0.4)\n",
      "Requirement already satisfied: aiohttp<4.0.0,>=3.9.1 in /usr/local/lib/python3.10/dist-packages (from langchain_nvidia_ai_endpoints) (3.9.3)\n",
      "Requirement already satisfied: langchain-core<0.2.0,>=0.1.5 in /usr/local/lib/python3.10/dist-packages (from langchain_nvidia_ai_endpoints) (0.1.29)\n",
      "Requirement already satisfied: pillow<11.0.0,>=10.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain_nvidia_ai_endpoints) (10.3.0)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.9.1->langchain_nvidia_ai_endpoints) (1.3.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.9.1->langchain_nvidia_ai_endpoints) (23.1.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.9.1->langchain_nvidia_ai_endpoints) (1.3.3)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.9.1->langchain_nvidia_ai_endpoints) (6.0.4)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.9.1->langchain_nvidia_ai_endpoints) (1.9.2)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.9.1->langchain_nvidia_ai_endpoints) (4.0.2)\n",
      "Requirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (6.0)\n",
      "Requirement already satisfied: anyio<5,>=3 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (3.7.1)\n",
      "Requirement already satisfied: jsonpatch<2.0,>=1.33 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (1.33)\n",
      "Requirement already satisfied: langsmith<0.2.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (0.1.39)\n",
      "Requirement already satisfied: packaging<24.0,>=23.2 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (23.2)\n",
      "Requirement already satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (1.10.7)\n",
      "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (2.31.0)\n",
      "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (8.2.3)\n",
      "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (3.4)\n",
      "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (1.3.1)\n",
      "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (1.1.1)\n",
      "Requirement already satisfied: jsonpointer>=1.9 in /usr/local/lib/python3.10/dist-packages (from jsonpatch<2.0,>=1.33->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (2.4)\n",
      "Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /usr/local/lib/python3.10/dist-packages (from langsmith<0.2.0,>=0.1.0->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (3.10.0)\n",
      "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (4.10.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (3.1.0)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (2.2.1)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain-core<0.2.0,>=0.1.5->langchain_nvidia_ai_endpoints) (2022.12.7)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Collecting faiss-cpu\n",
      "  Downloading faiss_cpu-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from faiss-cpu) (1.22.2)\n",
      "Downloading faiss_cpu-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.0 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m27.0/27.0 MB\u001b[0m \u001b[31m16.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: faiss-cpu\n",
      "Successfully installed faiss-cpu-1.8.0\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install langchain\n",
    "!pip install langchain_nvidia_ai_endpoints\n",
    "!pip install faiss-cpu\n",
    "!pip install beautifulsoup4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "980506c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from langchain.chains import ConversationalRetrievalChain, LLMChain\n",
    "from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
    "from langchain.chains.question_answering import load_qa_chain\n",
    "from langchain.memory import ConversationBufferMemory\n",
    "from langchain.vectorstores import FAISS\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
    "from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "025de714",
   "metadata": {},
   "source": [
    "Set up API key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bf9a84ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "Enter your NVIDIA API key:  ······································································\n"
     ]
    }
   ],
   "source": [
    "import getpass\n",
    "\n",
    "if not os.environ.get(\"NVIDIA_API_KEY\", \"\").startswith(\"nvapi-\"):\n",
    "    nvapi_key = getpass.getpass(\"Enter your NVIDIA API key: \")\n",
    "    assert nvapi_key.startswith(\"nvapi-\"), f\"{nvapi_key[:5]}... is not a valid key\"\n",
    "    os.environ[\"NVIDIA_API_KEY\"] = nvapi_key"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91fcd102",
   "metadata": {},
   "source": [
    "Define a helper function that downloads an HTML page and returns its plain text. We'll use it later to load the relevant HTML documents from the Triton documentation website and convert them to a vector store."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d84c5ef5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from typing import List, Union\n",
    "\n",
    "import requests\n",
    "from bs4 import BeautifulSoup\n",
    "\n",
    "def html_document_loader(url: Union[str, bytes]) -> str:\n",
    "    \"\"\"\n",
    "    Loads the HTML content of a document from a given URL and returns its plain-text content.\n",
    "\n",
    "    Args:\n",
    "        url: The URL of the document.\n",
    "\n",
    "    Returns:\n",
    "        The text content of the document, or an empty string if the\n",
    "        HTTP request or the HTML parsing fails (the error is printed).\n",
    "\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # A timeout keeps the notebook from hanging on an unresponsive server\n",
    "        response = requests.get(url, timeout=30)\n",
    "        html_content = response.text\n",
    "    except Exception as e:\n",
    "        print(f\"Failed to load {url} due to exception {e}\")\n",
    "        return \"\"\n",
    "\n",
    "    try:\n",
    "        # Create a Beautiful Soup object to parse html\n",
    "        soup = BeautifulSoup(html_content, \"html.parser\")\n",
    "\n",
    "        # Remove script and style tags\n",
    "        for script in soup([\"script\", \"style\"]):\n",
    "            script.extract()\n",
    "\n",
    "        # Get the plain text from the HTML document\n",
    "        text = soup.get_text()\n",
    "\n",
    "        # Remove excess whitespace and newlines\n",
    "        text = re.sub(r\"\\s+\", \" \", text).strip()\n",
    "\n",
    "        return text\n",
    "    except Exception as e:\n",
    "        print(f\"Exception {e} while loading document\")\n",
    "        return \"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad3d3f0c",
   "metadata": {},
   "source": [
    "Read the HTML pages and split the text in preparation for embedding generation.\n",
    "\n",
    "Note: the `chunk_size` value must suit the specific model used for embedding generation.\n",
    "\n",
    "Pay attention to the `chunk_size` parameter of the text splitter. Setting the right chunk size is critical for RAG performance, as much of a RAG’s success depends on the retrieval step finding the right context for generation. The entire prompt (retrieved chunks + user query) must fit within the LLM’s context window, so keep chunks small enough to leave room for the expected query. For example, while OpenAI LLMs have a context window of 8k-32k tokens, Llama2 is limited to 4k tokens. Experiment with different chunk sizes, but typical values should be 100-600, depending on the LLM. In the cell below, `chunk_size` is measured in characters because `length_function=len`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6f48635f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_embeddings(embedding_path: str = \"./embed\"):\n",
    "    \"\"\"Download the Triton pages, split them into chunks, and store their embeddings in embedding_path.\"\"\"\n",
    "    print(f\"Storing embeddings to {embedding_path}\")\n",
    "\n",
    "    # List of web pages containing NVIDIA Triton technical documentation\n",
    "    urls = [\n",
    "         \"https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html\",\n",
    "         \"https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/getting_started/quickstart.html\",\n",
    "         \"https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_repository.html\",\n",
    "         \"https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_analyzer.html\",\n",
    "         \"https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/architecture.html\",\n",
    "    ]\n",
    "\n",
    "    # Download each page; a page that fails to load comes back as an empty string\n",
    "    documents = []\n",
    "    for url in urls:\n",
    "        document = html_document_loader(url)\n",
    "        documents.append(document)\n",
    "\n",
    "    text_splitter = RecursiveCharacterTextSplitter(\n",
    "        chunk_size=1000,  # counted in characters, since length_function=len\n",
    "        chunk_overlap=0,\n",
    "        length_function=len,\n",
    "    )\n",
    "    # Keep the source URL of each page as metadata on its chunks\n",
    "    texts = text_splitter.create_documents(documents, metadatas=[{\"source\": u} for u in urls])\n",
    "    # index_docs does not use its url argument; the last URL is passed as a placeholder\n",
    "    index_docs(url, text_splitter, texts, embedding_path)\n",
    "    print(\"Generated embeddings successfully\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "942934e8",
   "metadata": {},
   "source": [
    "Generate embeddings using NVIDIA AI Endpoints for LangChain and save them to an offline FAISS vector store in the `./embed` directory for future reuse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "27d1aced",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.documents import Document\n",
    "\n",
    "def index_docs(url: Union[str, bytes], splitter, documents: List[Document], dest_embed_dir) -> None:\n",
    "    \"\"\"\n",
    "    Split the documents into chunks, embed them, and add them to the FAISS vector store\n",
    "\n",
    "    Args:\n",
    "        url: Source url for the documents (currently unused).\n",
    "        splitter: Splitter used to split the documents\n",
    "        documents: list of documents whose embeddings need to be created\n",
    "        dest_embed_dir: destination directory for embeddings\n",
    "\n",
    "    Returns:\n",
    "        None\n",
    "    \"\"\"\n",
    "    embeddings = NVIDIAEmbeddings(model=\"nvolveqa_40k\")\n",
    "\n",
    "    for document in documents:\n",
    "        texts = splitter.split_text(document.page_content)\n",
    "        if not texts:\n",
    "            continue  # nothing to embed for an empty document\n",
    "\n",
    "        # attach the document's metadata to every chunk\n",
    "        metadatas = [document.metadata] * len(texts)\n",
    "\n",
    "        # create embeddings and add to vector store\n",
    "        if os.path.exists(dest_embed_dir):\n",
    "            update = FAISS.load_local(folder_path=dest_embed_dir, embeddings=embeddings)\n",
    "            update.add_texts(texts, metadatas=metadatas)\n",
    "            update.save_local(folder_path=dest_embed_dir)\n",
    "        else:\n",
    "            docsearch = FAISS.from_texts(texts, embedding=embeddings, metadatas=metadatas)\n",
    "            docsearch.save_local(folder_path=dest_embed_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9831f7ba",
   "metadata": {},
   "source": [
    "### Second stage: load the embeddings from the vector store and build a RAG chain using NVIDIAEmbeddings\n",
    "\n",
    "The cell below first calls `create_embeddings()` to build and save the vector store, then creates the embeddings model using the NVIDIA Retrieval QA Embedding endpoint. This model represents words, phrases, or other entities as vectors of numbers and captures the relations between words and phrases. See here for reference: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/nvolve-40k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f56cadd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Storing embeddings to ./embed\n"
     ]
    }
   ],
   "source": [
    "create_embeddings()\n",
    "\n",
    "embedding_model = NVIDIAEmbeddings(model=\"nvolveqa_40k\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e211270",
   "metadata": {},
   "source": [
    "Load the saved vector store from disk using FAISS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "648b9d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved FAISS index; embedding_model is used to embed incoming queries\n",
    "embedding_path = \"embed/\"\n",
    "docsearch = FAISS.load_local(folder_path=embedding_path, embeddings=embedding_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01153bc4",
   "metadata": {},
   "source": [
    "Create a ConversationalRetrievalChain using ChatNVIDIA models. In this chain we demonstrate the use of two LLMs: one condenses the chat history and the follow-up question into a standalone question, and the other answers that question from the retrieved documents. This improves the overall result in more complicated scenarios. We'll use Llama2 70B for the first LLM and Mixtral for the chat element in the chain. We add a question_generator to produce the standalone query used for retrieval. See here for reference: https://python.langchain.com/docs/modules/chains/popular/chat_vector_db#conversationalretrievalchain-with-streaming-to-stdout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e460822",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatNVIDIA(model=\"llama2_70b\")\n",
    "\n",
    "memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
    "\n",
    "question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n",
    "\n",
    "chat = ChatNVIDIA(model=\"mixtral_8x7b\", temperature=0.1, max_tokens=1000, top_p=1.0)\n",
    "\n",
    "doc_chain = load_qa_chain(chat, chain_type=\"stuff\", prompt=QA_PROMPT)\n",
    "\n",
    "qa = ConversationalRetrievalChain(\n",
    "    retriever=docsearch.as_retriever(),\n",
    "    combine_docs_chain=doc_chain,\n",
    "    memory=memory,\n",
    "    question_generator=question_generator,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2749482",
   "metadata": {},
   "source": [
    "Ask any question about Triton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f5ead62",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What is Triton?\"\n",
    "result = qa({\"question\": query})\n",
    "print(result.get(\"answer\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c1d7dd9",
   "metadata": {},
   "source": [
    "Ask another question about Triton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5e80a22",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What interfaces does Triton support?\"\n",
    "result = qa({\"question\": query})\n",
    "print(result.get(\"answer\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd01e957",
   "metadata": {},
   "source": [
    "Finally, showcase chat capabilities by asking a follow-up question about the previous query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a222b8e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"But why?\"\n",
    "result = qa({\"question\": query})\n",
    "print(result.get(\"answer\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d231b9df",
   "metadata": {},
   "source": [
    "Now we demonstrate a simpler chain that uses only a single chat LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2a2f90c",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatNVIDIA(model=\"llama2_70b\", temperature=0.1, max_tokens=1000, top_p=1.0)\n",
    "\n",
    "# Start from an empty history so the answers are not influenced by the previous chain's conversation\n",
    "memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
    "\n",
    "qa = ConversationalRetrievalChain.from_llm(\n",
    "    llm=llm,\n",
    "    retriever=docsearch.as_retriever(),\n",
    "    chain_type=\"stuff\",\n",
    "    memory=memory,\n",
    "    combine_docs_chain_kwargs={\"prompt\": QA_PROMPT},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7253f735",
   "metadata": {},
   "source": [
    "Now try asking a question about Triton with the simpler chain. Compare the answer with the one from the previous two-LLM chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b22dcbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What is Triton?\"\n",
    "result = qa({\"question\": query})\n",
    "print(result.get(\"answer\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63fd45fc",
   "metadata": {},
   "source": [
    "Ask another question about Triton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c81f2d55",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"Does Triton support ONNX?\"\n",
    "result = qa({\"question\": query})\n",
    "print(result.get(\"answer\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58caaebb",
   "metadata": {},
   "source": [
    "Finally, showcase chat capabilities by asking a follow-up question about the previous query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dea39f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"But why?\"\n",
    "result = qa({\"question\": query})\n",
    "print(result.get(\"answer\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
